{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the environment, and use cuda if present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch, json, numpy\n",
    "from netdissect import proggan, nethook, easydict, zdataset\n",
    "from netdissect.plotutil import plot_tensor_images, plot_max_heatmap\n",
    "\n",
    "# Pick the GPU when torch reports CUDA as available; otherwise use the CPU.\n",
    "# Later cells move the model and the z batches onto this device.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load a generator model, and instrument a layer for modification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the generator from the dining-room checkpoint and move it to `device`.\n",
    "model = proggan.from_pth_file('models/karras/diningroom_lsun.pth').to(device)\n",
    "# Instrument 'layer4' for both reading and editing. make_mixes below reads\n",
    "# model.retained['layer4'] and writes model.ablation / model.replacement.\n",
    "# NOTE(review): presumably retain_layers provides `retained` and edit_layers\n",
    "# provides `ablation` / `replacement` -- confirm against netdissect.nethook.\n",
    "nethook.retain_layers(model, ['layer4'])\n",
    "nethook.edit_layers(model, ['layer4'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the dissection, find the highest ranked table units."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dining-room dissection report and take the record for 'layer4'.\n",
    "dissect = easydict.load_json('dissect/diningroom/dissect.json')\n",
    "lrec = next(x for x in dissect.layers if x.layer == 'layer4')\n",
    "# Within that layer, take the ranking named 'table-iou'.\n",
    "rrec = next(x for x in lrec.rankings if x.name == 'table-iou')\n",
    "# Keep the first 20 unit indices in ascending-score order.\n",
    "# NOTE(review): argsort is ascending, so these are the top-ranked units only\n",
    "# if scores are stored so that lower is better -- confirm against dissect.json.\n",
    "ct_units = torch.from_numpy(numpy.argsort(rrec.score)[:20])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate 30 example images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 30 z inputs for the model and move them to `device`\n",
    "# (the `[...]` index keeps every sample).\n",
    "zbatch = zdataset.z_sample_for_model(model, 30)[...].to(device)\n",
    "# Run the generator on the whole batch and display the resulting images.\n",
    "base_images = model(zbatch)\n",
    "plot_tensor_images(base_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function to permute the values of the selected units, and generate the resulting images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_mixes(model, layer, units, zbatch, mixcount=5, seed=None):\n",
    "    \"\"\"Render each z with the features of `units` permuted among themselves.\n",
    "\n",
    "    The model is first run unedited on `zbatch` and the output of `layer` is\n",
    "    read from `model.retained[layer]`. For each of the `mixcount - 1` variants\n",
    "    a random permutation of `units` is drawn, the features of those units are\n",
    "    reordered along dim 1, and the model is rerun with `model.ablation[layer]`\n",
    "    set to a 0/1 mask over the units and `model.replacement[layer]` set to the\n",
    "    reordered features.\n",
    "    NOTE(review): presumably nethook substitutes the replacement wherever the\n",
    "    mask is 1 -- confirm against netdissect.nethook.\n",
    "\n",
    "    Args:\n",
    "        model: module instrumented so that `model.retained`, `model.ablation`\n",
    "            and `model.replacement` can be indexed by `layer`.\n",
    "        layer: name of the instrumented layer to edit.\n",
    "        units: 1-d int64 tensor of unit indices along dim 1 of the layer\n",
    "            output (used both for indexing and as a scatter_ index).\n",
    "        zbatch: input batch passed to `model`.\n",
    "        mixcount: images produced per input: 1 unedited + mixcount - 1 mixes.\n",
    "        seed: optional int. When given, permutations come from a private\n",
    "            numpy RandomState(seed) so the mixes are reproducible; when None\n",
    "            the global numpy RNG is used, as before.\n",
    "\n",
    "    Returns:\n",
    "        Tensor allocated with torch.zeros, of shape\n",
    "        (N * mixcount,) + image_shape where N = number of generated images.\n",
    "        Rows b*mixcount .. b*mixcount + mixcount - 1 belong to input b, the\n",
    "        first of them being the unedited image.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if mixcount < 1.\n",
    "    \"\"\"\n",
    "    if mixcount < 1:\n",
    "        raise ValueError('mixcount must be at least 1, got %r' % (mixcount,))\n",
    "    # Use the global numpy RNG unless the caller asks for a reproducible one.\n",
    "    rng = numpy.random if seed is None else numpy.random.RandomState(seed)\n",
    "    # Start from an unedited model so the base pass is not affected by edits\n",
    "    # left over from earlier cells.\n",
    "    model.ablation[layer] = None\n",
    "    model.replacement[layer] = None\n",
    "    try:\n",
    "        # Inference only: do not keep autograd graphs for the forward passes.\n",
    "        with torch.no_grad():\n",
    "            base_images = model(zbatch)\n",
    "            base_features = model.retained[layer]\n",
    "            result = torch.zeros(\n",
    "                (base_images.shape[0] * mixcount, ) + base_images.shape[1:])\n",
    "            result[0::mixcount] = base_images\n",
    "            # The unit mask is identical for every mix, so build it once.\n",
    "            ablation = torch.zeros(base_features.shape[1])\n",
    "            ablation.scatter_(0, units, 1)\n",
    "            for i in range(1, mixcount):\n",
    "                shuf = torch.from_numpy(rng.permutation(len(units)))\n",
    "                # Copy all features, then overwrite the selected units with\n",
    "                # the same units' features in shuffled order.\n",
    "                replacement = base_features.clone()\n",
    "                replacement[:, units] = base_features[:, units][:, shuf]\n",
    "                model.ablation[layer] = ablation\n",
    "                model.replacement[layer] = replacement\n",
    "                result[i::mixcount] = model(zbatch)\n",
    "    finally:\n",
    "        # Always leave the model unedited, even if a forward pass raised.\n",
    "        model.ablation[layer] = None\n",
    "        model.replacement[layer] = None\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call the function to shuffle dining room tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Out of 30 z inputs keep samples 10, 15, 16 and 25 and move them to `device`.\n",
    "zbatch = zdataset.z_sample_for_model(model, 30)[[10,15,16,25]].to(device)\n",
    "# With the default mixcount=5 each sample yields its unedited image followed\n",
    "# by 4 variants with the table units permuted.\n",
    "plot_tensor_images(make_mixes(model, 'layer4', ct_units, zbatch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same procedure as for tables, now with the 'door-iou' ranking of layer4.\n",
    "# Note: this rebinds `rrec`, which previously held the 'table-iou' ranking.\n",
    "rrec = next(x for x in lrec.rankings if x.name == 'door-iou')\n",
    "door_units = torch.from_numpy(numpy.argsort(rrec.score)[:20])\n",
    "\n",
    "# Out of 50 z inputs keep samples 13, 29, 34 and 42, then plot the unedited\n",
    "# image and the door-unit mixes for each of them.\n",
    "zbatch = zdataset.z_sample_for_model(model, 50)[[13,29,34,42]].to(device)\n",
    "plot_tensor_images(make_mixes(model, 'layer4', door_units, zbatch))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `model` to the generator from the church-outdoor checkpoint; every\n",
    "# later use of `model` runs on this network instead of the dining-room one.\n",
    "model = proggan.from_pth_file('models/karras/churchoutdoor_lsun.pth').to(device)\n",
    "# Instrument 'layer4' of the new model in the same way as before.\n",
    "nethook.retain_layers(model, ['layer4'])\n",
    "nethook.edit_layers(model, ['layer4'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 50 z inputs, run the current model on all of them and show the images.\n",
    "zbatch = zdataset.z_sample_for_model(model, 50)[...].to(device)\n",
    "base_images = model(zbatch)\n",
    "plot_tensor_images(base_images)\n",
    "\n",
    "# Take the layer4 output retained from that forward pass, keep only the\n",
    "# channels listed in door_units, and plot them as a heatmap.\n",
    "# NOTE(review): in top-to-bottom order door_units comes from the dining-room\n",
    "# dissection while `model` is now the church generator -- confirm that\n",
    "# reusing these unit indices across models is intended.\n",
    "features = model.retained['layer4']\n",
    "specific_features = features[:,door_units]\n",
    "plot_max_heatmap(specific_features, shape=(256,256))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}